Laplace Redux Redux

Severin Bratus

Theoretic Introduction

Importance of uncertainty calibration

Medical diagnosis

Autonomous driving

Predictive justice

Finance

What is well-calibrated?

The Bayesian approach

posterior \(p(\theta \mid \mathcal{D}) = \tfrac{1}{Z} \,p(\mathcal{D} \mid \theta) \, p(\theta)\)
likelihood \(p(\mathcal{D} \mid \theta)\)
prior \(p(\theta)\)
evidence \(Z := p(\mathcal{D}) = \textstyle\int p(\mathcal{D} \mid \theta) \, p(\theta) \,d\theta\)
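To make the four ingredients concrete, here is an illustrative discrete example (a made-up coin-bias problem, not from the talk) where the evidence integral collapses to a sum; shown in NumPy for brevity, though the package discussed later is Julia:

```python
import numpy as np

# theta = candidate coin biases on a grid, so Z is a sum, not an integral
theta = np.linspace(0.01, 0.99, 99)

# uniform prior p(theta)
prior = np.ones_like(theta) / len(theta)

# observed data D: 7 heads, 3 tails -> likelihood p(D | theta)
heads, tails = 7, 3
likelihood = theta**heads * (1 - theta)**tails

# evidence Z = sum over theta of p(D | theta) p(theta)
Z = np.sum(likelihood * prior)

# posterior p(theta | D) = likelihood * prior / Z
posterior = likelihood * prior / Z
```

With a uniform prior the posterior peaks at the maximum-likelihood bias, here 7/10.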

loss = – log(likelihood)

min loss = max likelihood = max posterior (if prior uniform)

  • MLE = maximum likelihood estimate
  • MAP = maximum a posteriori
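Spelled out, the equivalences above are the standard identities

```latex
\[
\theta_\text{MLE}
  = \arg\max_\theta \, p(\mathcal{D} \mid \theta)
  = \arg\min_\theta \big( -\log p(\mathcal{D} \mid \theta) \big),
\qquad
\theta_\text{MAP}
  = \arg\max_\theta \big( \log p(\mathcal{D} \mid \theta) + \log p(\theta) \big)
\]
```

so MAP reduces to MLE under a uniform prior, where \(\log p(\theta)\) is constant in \(\theta\).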

A physical metaphor

The Laplace approximation

posterior \(p(\theta \mid \mathcal{D}) \approx \mathcal{N}(\theta; \mu, \varSigma)\)
centered at \(\mu := \theta_\text{MAP}\)
with covariance \(\varSigma := H^{-1}\)
where \(H = \nabla^2_\theta \mathcal{L}(\mathcal{D};\theta) \vert_{\theta_\text{MAP}}\)
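The recipe above can be exercised end-to-end on a toy one-dimensional posterior. The sketch below (NumPy, with a made-up strictly convex loss; grid search and finite differences stand in for the optimizer and autodiff) finds the MAP, measures the curvature \(H\) at the mode, and inverts it to get the Gaussian's variance:

```python
import numpy as np

# Made-up 1-D negative log-posterior ("loss"); strictly convex, so the MAP is unique.
def loss(t):
    return (t - 2.0)**4 + 0.5 * t**2

# MAP estimate mu: minimize the loss (here by a fine grid search).
grid = np.linspace(-5.0, 5.0, 200001)
mu = grid[np.argmin(loss(grid))]

# Curvature at the mode: H = d^2 loss / d t^2, via a central finite difference.
eps = 1e-3
H = (loss(mu + eps) - 2 * loss(mu) + loss(mu - eps)) / eps**2

# Laplace approximation: posterior ~ N(mu, sigma2) with sigma2 = 1 / H.
sigma2 = 1.0 / H
```

A sharper mode (larger \(H\)) yields a smaller variance: the curvature of the loss at the MAP is exactly what calibrates the uncertainty.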

The posterior predictive

Probability of a label \(y\) on a new input \(x_*\), averaging the model's prediction \(f_\theta(x_*)\) over the posterior.

\[ p(y \mid f(x_*), \mathcal{D}) = \int p(y \mid f_\theta(x_*)) \, p(\theta \mid \mathcal{D}) \,d\theta \]
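For neural networks this integral is intractable; a common approximation (see, e.g., Daxberger et al. 2022) is Monte Carlo sampling from the Laplace posterior:

```latex
\[
p(y \mid f(x_*), \mathcal{D})
  \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y \mid f_{\theta_s}(x_*)\big),
\qquad
\theta_s \sim \mathcal{N}(\theta_\text{MAP}, \varSigma)
\]
```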

The Hessian

\[ (H_{f})_{i,j}={\frac {\partial ^{2}f}{\partial x_{i}\,\partial x_{j}}} \]

\[ H_{f} = \begin{bmatrix} {\dfrac {\partial ^{2}f}{\partial x_{1}^{2}}}&\cdots &{\dfrac {\partial ^{2}f}{\partial x_{1}\,\partial x_{n}}}\\[2.2ex] \vdots &\ddots &\vdots \\[2.2ex] {\dfrac {\partial ^{2}f}{\partial x_{n}\,\partial x_{1}}}&\cdots &{\dfrac {\partial ^{2}f}{\partial x_{n}^{2}}} \end{bmatrix} \]
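A concrete two-variable instance of the definition (symmetric, by equality of mixed partials):

```latex
\[
f(x_1, x_2) = x_1^2 \, x_2
\quad\Longrightarrow\quad
H_f = \begin{bmatrix} 2x_2 & 2x_1 \\ 2x_1 & 0 \end{bmatrix}
\]
```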

The Fisher information matrix

\[ F := \textstyle\sum_{n=1}^N \mathbb{E}_{\widehat{y} \sim p(y \mid f_\theta(x_n))} \left[ g g^\intercal \right], \qquad g := \nabla_\theta \log p(\widehat{y} \mid f_\theta(x_n)) \big\vert_{\theta_\text{MAP}} \]

The generalized Gauss-Newton

\[ G := \textstyle\sum_{n=1}^N J(x_n) \left( \nabla^2_{f} \log p(y_n \mid f) \big\vert_{f=f_{\theta_\text{MAP}}(x_n)} \right) J(x_n)^\intercal, \qquad J(x_n) := \nabla_\theta f_\theta(x_n) \big\vert_{\theta_\text{MAP}} \]

Weight subsets

Approximate Hessian structures

LaplaceRedux.jl

TODO describe task & network for demo

True Hessian

using Flux, Zygote

# Flatten all network parameters into a single vector `theta`;
# `rebuild` maps such a vector back to a network.
theta, rebuild = Flux.destructure(nn)

# Loss as a function of the flat parameter vector.
function loss_vec(theta::Vector)
    nn_rebuilt = rebuild(theta)  # rebuild the network from the parameter vector
    Flux.Losses.logitcrossentropy(nn_rebuilt(X), Y)
end;

# Dense Hessian of the loss, evaluated at `theta`.
H = Zygote.hessian(loss_vec, theta)

True Hessian

Generalized Gauss-Newton (GGN)

GGN error

Last layer only

Block-diagonal Laplace

Last layer only

Our contributions

  • Laplace on multi-class classification
  • Generalized Gauss-Newton
  • Batched computations
  • Block-diagonal methods
  • MLJ.jl interface

The good, the bad, and the ugly

What we found nice

  • Metaprogramming
  • Julia standard API
  • Flux/Zygote

Pain points

  • Compile & load times
  • Obscure stack traces
  • Limited LSP & Unicode support in JupyterLab
  • Zygote not self-differentiable
  • No second-order information from Zygote
  • No branch coverage
  • No ONNX

Acknowledgements

  • Team:
    • Mark Ardman
    • Severin Bratus
    • Adelina Cazacu
    • Andrei Ionescu
    • Ivan Makarov
  • Patrick Altmeyer
  • CSE2000 Software Project course @ TU Delft

References

Amini, Alexander, Ava Soleimany, Sertac Karaman, and Daniela Rus. 2019. “Spatial Uncertainty Sampling for End-to-End Control.” https://arxiv.org/abs/1805.04829.
Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2022. “Laplace Redux – Effortless Bayesian Deep Learning.” https://arxiv.org/abs/2106.14806.
Martens, James, and Roger Grosse. 2015. “Optimizing Neural Networks with Kronecker-factored Approximate Curvature.” https://arxiv.org/abs/1503.05671.

* stock images generated by Stable Diffusion

Q&A

s bratus [at] student tudelft nl